#' @title Survival Bayesian Additive Regression Trees Learner
#' @author bblodfon
#' @name mlr_learners_surv.bart
#'
#' @description
#' Fits a Bayesian Additive Regression Trees (BART) learner to right-censored
#' survival data. Calls [BART::mc.surv.bart()] from \CRANpkg{BART}, or
#' [BART::surv.bart()] when running on Windows.
#'
#' @section Prediction types:
#' This learner returns two prediction types:
#' 1. `distr`: a 3d survival array with observations as 1st dimension, time
#' points as 2nd and the posterior draws as 3rd dimension.
#' Calculated using the internal `predict.survbart()` function.
#' 2. `crank`: the expected mortality using [mlr3proba::.surv_return()]. The parameter
#' `which.curve` decides which posterior draw (3rd dimension) will be used for the
#' calculation of the expected mortality. Note that the median posterior is
#' by default used for the calculation of survival measures that require a `distr`
#' prediction, see more info on [PredictionSurv][mlr3proba::PredictionSurv].
#'
#' @section Initial parameter values:
#' - `mc.cores` is initialized to 1 to avoid threading conflicts with \CRANpkg{future}.
#'
#' @section Custom mlr3 parameters:
#' - `quiet` suppresses the messages generated by the wrapped C++ code during
#' prediction. It is initialized to `TRUE`.
#' - `importance` selects the type of importance score. The default is `count`,
#' see the documentation of the method `$importance()` for more details.
#' - `which.curve` selects which posterior draw is used to calculate the
#' `crank` prediction. A value in (0, 1) is interpreted as a quantile of the
#' posterior curves and a value greater than 1 as the index of the curve; the
#' special value `"mean"` is also accepted. By default the
#' **median posterior** is used, i.e. `which.curve` is 0.5.
#'
#' @templateVar id surv.bart
#' @template learner
#'
#' @references
#' `r format_bib("sparapani2021nonparametric", "chipman2010bart")`
#'
#' @template seealso_learner
#' @template simple_example
#' @export
LearnerSurvLearnerSurvBART = R6Class("LearnerSurvLearnerSurvBART",
  inherit = mlr3proba::LearnerSurv,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      param_set = ps(
        # survival-grid parameters: forwarded to the fitting function during
        # training and to `BART::surv.pre.bart()` during prediction
        K = p_dbl(default = NULL, special_vals = list(NULL), lower = 1, tags = c("train", "predict")),
        events = p_uty(default = NULL, tags = c("train", "predict")),
        ztimes = p_uty(default = NULL, tags = c("train", "predict")),
        zdelta = p_uty(default = NULL, tags = c("train", "predict")),
        # parameters forwarded to the fitting function only
        sparse = p_lgl(default = FALSE, tags = "train"),
        theta = p_dbl(default = 0, tags = "train"),
        omega = p_dbl(default = 1, tags = "train"),
        a = p_dbl(default = 0.5, lower = 0.5, upper = 1, tags = "train"),
        b = p_dbl(default = 1L, tags = "train"),
        augment = p_lgl(default = FALSE, tags = "train"),
        rho = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
        usequants = p_lgl(default = FALSE, tags = "train"),
        rm.const = p_lgl(default = TRUE, tags = "train"),
        type = p_fct(levels = c("pbart", "lbart"), default = "pbart", tags = "train"),
        ntype = p_int(lower = 1, upper = 3, tags = "train"),
        k = p_dbl(default = 2.0, lower = 0, tags = "train"),
        power = p_dbl(default = 2.0, lower = 0, tags = "train"),
        base = p_dbl(default = 0.95, lower = 0, upper = 1, tags = "train"),
        offset = p_dbl(default = NULL, special_vals = list(NULL), tags = "train"),
        ntree = p_int(default = 50L, lower = 1L, tags = "train"),
        numcut = p_int(default = 100L, lower = 1L, tags = "train"),
        ndpost = p_int(default = 1000L, lower = 1L, tags = "train"),
        nskip = p_int(default = 250L, lower = 0L, tags = "train"),
        keepevery = p_int(default = 10L, lower = 1L, tags = "train"),
        printevery = p_int(default = 100L, lower = 1L, tags = "train"),
        seed = p_int(default = 99L, tags = "train"),
        # forwarded to the fitting function and to `predict()`
        mc.cores = p_int(default = 2L, lower = 1L, tags = c("train", "predict")),
        nice = p_int(default = 19L, lower = 0L, upper = 19L, tags = c("train", "predict")),
        # NOTE(review): `openmp` is declared here but `.predict()` does not
        # forward it to `predict()` - confirm whether this is intended
        openmp = p_lgl(default = TRUE, tags = "predict"),
        # custom mlr3 parameters (not arguments of the BART functions)
        quiet = p_lgl(default = TRUE, tags = "predict"),
        importance = p_fct(default = "count", levels = c("count", "prob"), tags = "train"),
        which.curve = p_dbl(lower = 0L, special_vals = list("mean"), tags = "predict")
      )

      # custom defaults
      param_set$values = list(mc.cores = 1, quiet = TRUE, importance = "count",
                              which.curve = 0.5) # 0.5 quantile => median posterior

      super$initialize(
        id = "surv.bart",
        packages = "BART",
        feature_types = c("logical", "integer", "numeric"),
        predict_types = c("crank", "distr"),
        param_set = param_set,
        properties = c("importance", "missings"),
        man = "mlr3extralearners::mlr_learners_surv.bart",
        label = "Bayesian Additive Regression Trees"
      )
    },

    #' @description
    #' Two types of importance scores are supported based on the value
    #' of the parameter `importance`:
    #' 1. `prob`: The mean selection probability of each feature in the trees,
    #' extracted from the slot `varprob.mean`.
    #' If `sparse = FALSE` (default), this is a fixed constant.
    #' Recommended to use this option when `sparse = TRUE`.
    #' 2. `count`: The mean observed count of each feature in the trees (average
    #' number of times the feature was used in a tree decision rule across all
    #' posterior draws), extracted from the slot `varcount.mean`.
    #' This is the default importance score and is also used when the
    #' `importance` parameter is unset.
    #'
    #' In both cases, higher values signify more important variables.
    #'
    #' @return Named `numeric()`.
    importance = function() {
      if (is.null(self$model$model)) {
        stopf("No model stored")
      }

      pars = self$param_set$get_values(tags = "train")

      # fall back to the declared default when the parameter has been unset,
      # otherwise the comparison below fails on a zero-length argument
      type = if (is.null(pars$importance)) "count" else pars$importance

      # the first element of each slot is dropped; presumably it belongs to the
      # time covariate added by BART's survival transformation - TODO confirm
      scores = if (type == "prob") {
        self$model$model$varprob.mean[-1]
      } else {
        self$model$model$varcount.mean[-1]
      }

      sort(scores, decreasing = TRUE)
    }
  ),

  private = list(
    # Fits the BART survival model on the features and (time, status) target
    # of `task`. Returns a list holding the fitted model together with the
    # training features, times and status, which `.predict()` needs again.
    .train = function(task) {
      pars = self$param_set$get_values(tags = "train")
      pars$importance = NULL # not used in the train function

      x.train = as.data.frame(task$data(cols = task$feature_names)) # nolint
      truth = task$truth()
      times = truth[, 1]
      delta = truth[, 2] # delta => status

      # Choose the fitting function with a plain scalar `if`: `ifelse()` is
      # meant for vectors and only hands back a function through a special
      # single-element shortcut.
      # NOTE(review): presumably the multi-core variant is avoided on Windows
      # because it relies on forking - confirm.
      .fun = if (.Platform$OS.type == "windows") {
        BART::surv.bart
      } else {
        BART::mc.surv.bart
      }

      model = invoke(
        .fun,
        x.train = x.train,
        times = times,
        delta = delta,
        .args = pars
      )

      list(
        model = model,
        # need these for predict
        x.train = x.train,
        times = times,
        delta = delta
      )
    },

    # Predicts on `task` with the stored model and returns the result of
    # `mlr3proba::.surv_return()`, built from a 3d survival array
    # (observations x time points x posterior draws).
    .predict = function(task) {
      # get parameters with tag "predict"
      pars = self$param_set$get_values(tags = "predict")

      # get newdata and ensure same ordering in train and predict
      x.test = as.data.frame(ordered_features(task, self)) # nolint

      # subset parameters to use in `surv.pre.bart`
      pars_pre = pars[names(pars) %in% c("K", "events", "ztimes", "zdelta")]

      # transform data to be suitable for BART survival analysis (needs train data)
      trans_data = invoke(
        BART::surv.pre.bart,
        times   = self$model$times,
        delta   = self$model$delta,
        x.train = self$model$x.train,
        x.test  = x.test,
        .args   = pars_pre
      )

      # subset parameters to use in `predict`
      pars_pred = pars[names(pars) %in% c("mc.cores", "nice")]

      pred_fun = function() {
        invoke(
          predict,
          self$model$model,
          newdata = trans_data$tx.test,
          .args = pars_pred
        )
      }

      # fall back to the declared default (`TRUE`) when `quiet` has been unset,
      # otherwise `if ()` fails on a zero-length argument
      quiet = if (is.null(pars$quiet)) TRUE else pars$quiet

      # don't print C++ generated info during prediction
      if (quiet) {
        # the assignment inside the braces is evaluated in this function's
        # frame, so `pred` is available after the call
        utils::capture.output({
          pred = pred_fun()
        })
      } else {
        pred = pred_fun()
      }

      # Number of test observations
      N = task$nrow
      # Number of unique times
      K = pred$K
      times = pred$times
      # Number of posterior draws
      M = nrow(pred$surv.test)

      # Convert full posterior survival matrix to 3D survival array
      # See page 34-35 in Sparapani (2021) for more details
      # The matrix is first folded into (draws, times, observations) and then
      # permuted to (observations, times, draws)
      surv_array = aperm(
        array(pred$surv.test, dim = c(M, K, N), dimnames = list(NULL, times, NULL)),
        c(3, 2, 1)
      )

      # distr => 3d survival array
      # crank => expected mortality
      mlr3proba::.surv_return(times = times, surv = surv_array,
                              which.curve = pars$which.curve)
    }
  )
)

# Add the learner class to `.extralrns_dict` under the key "surv.bart"
.extralrns_dict$add("surv.bart", LearnerSurvLearnerSurvBART)
